{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:26.447801Z",
     "iopub.status.busy": "2022-07-17T07:22:26.447096Z",
     "iopub.status.idle": "2022-07-17T07:22:27.210138Z",
     "shell.execute_reply": "2022-07-17T07:22:27.209532Z",
     "shell.execute_reply.started": "2022-07-17T07:22:26.447544Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# pandas: load the data sets as DataFrames\n",
    "import pandas as pd\n",
    "# file reading\n",
    "import codecs\n",
    "\n",
    "\n",
    "def clean_text_columns(df):\n",
    "    \"\"\"Strip Title/Abstract in place and add a lower-cased 'text' column.\n",
    "\n",
    "    Missing titles or abstracts become empty strings, so a NaN cell cannot\n",
    "    raise during the string clean-up. Returns the same DataFrame.\n",
    "    \"\"\"\n",
    "    for column in ('Title', 'Abstract'):\n",
    "        df[column] = df[column].fillna('').astype(str).str.strip()\n",
    "    df['text'] = (df['Title'] + ' ' + df['Abstract']).str.lower()\n",
    "    return df\n",
    "\n",
    "\n",
    "train_df = pd.read_csv('./基于论文摘要的文本分类与查询性问答公开数据/train.csv', sep=',')\n",
    "test_df = pd.read_csv('./基于论文摘要的文本分类与查询性问答公开数据/test.csv', sep=',')\n",
    "\n",
    "# Keep only labelled rows. The explicit copy makes train_df an independent\n",
    "# frame, so the column assignments below do not act on a filtered slice.\n",
    "train_df = train_df[train_df['Topic(Label)'].notnull()].copy()\n",
    "# Replace topic names with integer codes; lbl holds the code -> name lookup.\n",
    "train_df['Topic(Label)'], lbl = pd.factorize(train_df['Topic(Label)'])\n",
    "\n",
    "train_df = clean_text_columns(train_df)\n",
    "test_df = clean_text_columns(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:08.358976Z",
     "start_time": "2021-03-11T09:35:08.347065Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:27.211619Z",
     "iopub.status.busy": "2022-07-17T07:22:27.211388Z",
     "iopub.status.idle": "2022-07-17T07:22:27.811932Z",
     "shell.execute_reply": "2022-07-17T07:22:27.811438Z",
     "shell.execute_reply.started": "2022-07-17T07:22:27.211602Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:19.935035Z",
     "start_time": "2021-03-11T09:35:09.908638Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:27.818257Z",
     "iopub.status.busy": "2022-07-17T07:22:27.818146Z",
     "iopub.status.idle": "2022-07-17T07:22:39.959290Z",
     "shell.execute_reply": "2022-07-17T07:22:39.958821Z",
     "shell.execute_reply.started": "2022-07-17T07:22:27.818242Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/.local/lib/python3.9/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (5.0.0)/charset_normalizer (2.0.11) doesn't match a supported version!\n",
      "  warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n"
     ]
    }
   ],
   "source": [
    "# pip install transformers\n",
    "# transformers: loading and using BERT-family models\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Tokenizer (and its vocabulary) for the bert-base-uncased checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
    "\n",
    "# Both splits are encoded with the same truncation / padding settings.\n",
    "encode_kwargs = dict(truncation=True, padding=True, max_length=512)\n",
    "train_texts = train_df['text'].to_list()\n",
    "test_texts = test_df['text'].to_list()\n",
    "train_encoding = tokenizer(train_texts, **encode_kwargs)\n",
    "test_encoding = tokenizer(test_texts, **encode_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:43.578135Z",
     "start_time": "2021-03-11T09:35:43.571452Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:39.961590Z",
     "iopub.status.busy": "2022-07-17T07:22:39.961445Z",
     "iopub.status.idle": "2022-07-17T07:22:39.966612Z",
     "shell.execute_reply": "2022-07-17T07:22:39.966277Z",
     "shell.execute_reply.started": "2022-07-17T07:22:39.961573Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Dataset wrapper around the tokenizer output\n",
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Map-style dataset pairing tokenizer encodings with optional labels.\n",
    "\n",
    "    Args:\n",
    "        encodings: mapping from field name to a sequence holding one entry\n",
    "            per sample; every field is converted to a tensor on access.\n",
    "        labels: optional sequence of integer class ids, one per sample.\n",
    "            When omitted (for example for an unlabelled set) the returned\n",
    "            items carry no 'labels' key.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encodings, labels=None):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return sample ``idx`` as a dict of tensors.\"\"\"\n",
    "        item = {key: torch.tensor(val[idx])\n",
    "                for key, val in self.encodings.items()}\n",
    "        if self.labels is not None:\n",
    "            item['labels'] = torch.tensor(int(self.labels[idx]))\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of samples: label count, or the length of the first field.\"\"\"\n",
    "        if self.labels is not None:\n",
    "            return len(self.labels)\n",
    "        return len(next(iter(self.encodings.values()), ()))\n",
    "\n",
    "\n",
    "BATCH_SIZE = 8\n",
    "\n",
    "# Training uses the factorized topic codes; the test split has no labels,\n",
    "# so a list of zeros stands in as a placeholder.\n",
    "train_labels = train_df['Topic(Label)'].to_list()\n",
    "placeholder_labels = [0] * len(test_df)\n",
    "train_dataset = XunFeiDataset(train_encoding, train_labels)\n",
    "test_dataset = XunFeiDataset(test_encoding, placeholder_labels)\n",
    "\n",
    "# From single samples to batches: shuffled for training, in order for test.\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:44.110121Z",
     "start_time": "2021-03-11T09:35:44.104871Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:43.423746Z",
     "iopub.status.busy": "2022-07-17T07:22:43.423167Z",
     "iopub.status.idle": "2022-07-17T07:22:43.430529Z",
     "shell.execute_reply": "2022-07-17T07:22:43.429752Z",
     "shell.execute_reply.started": "2022-07-17T07:22:43.423695Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Accuracy metric\n",
    "def flat_accuracy(preds, labels):\n",
    "    \"\"\"Return the fraction of rows whose arg-max along axis 1 equals the label.\n",
    "\n",
    "    Args:\n",
    "        preds: per-class scores, one row per sample.\n",
    "        labels: numpy array of class ids, one per sample.\n",
    "    \"\"\"\n",
    "    predicted = np.argmax(preds, axis=1).flatten()\n",
    "    expected = labels.flatten()\n",
    "    hits = np.sum(predicted == expected)\n",
    "    return hits / len(expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:58:49.027161Z",
     "start_time": "2021-03-11T09:58:45.317009Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:44.405381Z",
     "iopub.status.busy": "2022-07-17T07:22:44.404963Z",
     "iopub.status.idle": "2022-07-17T07:22:51.497373Z",
     "shell.execute_reply": "2022-07-17T07:22:51.496981Z",
     "shell.execute_reply.started": "2022-07-17T07:22:44.405334Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lyz/.local/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup\n",
    "\n",
    "# One output unit per topic discovered by pd.factorize (lbl) instead of a\n",
    "# hard-coded class count, so the head always matches the label encoding.\n",
    "model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(lbl))\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# Optimizer. transformers.AdamW is deprecated in favour of torch.optim.AdamW;\n",
    "# eps and weight_decay are passed explicitly to match the defaults of the\n",
    "# transformers implementation, so the update rule stays the same.\n",
    "optim = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)\n",
    "# Step count for a warm-up schedule; no scheduler is created in this notebook yet.\n",
    "total_steps = len(train_loader) * 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:44:22.077501Z",
     "start_time": "2021-03-11T09:39:16.473609Z"
    },
    "execution": {
     "iopub.execute_input": "2022-07-17T07:22:53.407167Z",
     "iopub.status.busy": "2022-07-17T07:22:53.406549Z",
     "iopub.status.idle": "2022-07-17T07:43:24.534800Z",
     "shell.execute_reply": "2022-07-17T07:43:24.534169Z",
     "shell.execute_reply.started": "2022-07-17T07:22:53.407116Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------Epoch: 0 ----------------\n",
      "epoth: 0, iter_num: 100, loss: 2.2585, 4.85%\n",
      "epoth: 0, iter_num: 200, loss: 1.7927, 9.69%\n",
      "epoth: 0, iter_num: 300, loss: 0.9374, 14.54%\n",
      "epoth: 0, iter_num: 400, loss: 0.4790, 19.39%\n",
      "epoth: 0, iter_num: 500, loss: 0.5791, 24.24%\n",
      "epoth: 0, iter_num: 600, loss: 0.3968, 29.08%\n",
      "epoth: 0, iter_num: 700, loss: 0.2969, 33.93%\n",
      "epoth: 0, iter_num: 800, loss: 0.3103, 38.78%\n",
      "epoth: 0, iter_num: 900, loss: 0.0524, 43.63%\n",
      "epoth: 0, iter_num: 1000, loss: 0.0366, 48.47%\n",
      "epoth: 0, iter_num: 1100, loss: 0.2020, 53.32%\n",
      "epoth: 0, iter_num: 1200, loss: 0.0448, 58.17%\n",
      "epoth: 0, iter_num: 1300, loss: 0.6537, 63.02%\n",
      "epoth: 0, iter_num: 1400, loss: 0.0215, 67.86%\n",
      "epoth: 0, iter_num: 1500, loss: 0.8063, 72.71%\n",
      "epoth: 0, iter_num: 1600, loss: 0.0250, 77.56%\n",
      "epoth: 0, iter_num: 1700, loss: 0.1452, 82.40%\n",
      "epoth: 0, iter_num: 1800, loss: 0.2006, 87.25%\n",
      "epoth: 0, iter_num: 1900, loss: 0.0123, 92.10%\n",
      "epoth: 0, iter_num: 2000, loss: 0.8850, 96.95%\n",
      "Epoch: 0, Average training loss: 0.6066\n",
      "------------Epoch: 1 ----------------\n",
      "epoth: 1, iter_num: 100, loss: 0.0326, 4.85%\n",
      "epoth: 1, iter_num: 200, loss: 0.0142, 9.69%\n",
      "epoth: 1, iter_num: 300, loss: 0.5070, 14.54%\n",
      "epoth: 1, iter_num: 400, loss: 0.0405, 19.39%\n",
      "epoth: 1, iter_num: 500, loss: 0.3428, 24.24%\n",
      "epoth: 1, iter_num: 600, loss: 0.0330, 29.08%\n",
      "epoth: 1, iter_num: 700, loss: 0.8683, 33.93%\n",
      "epoth: 1, iter_num: 800, loss: 0.0227, 38.78%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [7]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     62\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m3\u001b[39m):\n\u001b[1;32m     63\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m------------Epoch: \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m ----------------\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m epoch)\n\u001b[0;32m---> 64\u001b[0m     \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "Input \u001b[0;32mIn [7]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;66;03m# 正向传播\u001b[39;00m\n\u001b[1;32m      9\u001b[0m     optim\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 11\u001b[0m     input_ids \u001b[38;5;241m=\u001b[39m \u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m     attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m     13\u001b[0m     labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mto(device)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Training function\n",
    "def train(epoch_index=None):\n",
    "    \"\"\"Run one pass over ``train_loader`` and update ``model`` in place.\n",
    "\n",
    "    Args:\n",
    "        epoch_index: epoch number shown in the progress messages. When\n",
    "            omitted, the module-level ``epoch`` loop variable is used if it\n",
    "            exists (matching the zero-argument call style), otherwise 0.\n",
    "    \"\"\"\n",
    "    if epoch_index is None:\n",
    "        # Fall back to the notebook-level loop variable instead of failing\n",
    "        # with NameError when train() is called outside the epoch loop.\n",
    "        epoch_index = globals().get('epoch', 0)\n",
    "\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for iter_num, batch in enumerate(train_loader, start=1):\n",
    "        # Forward pass\n",
    "        optim.zero_grad()\n",
    "\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        total_train_loss += loss.item()\n",
    "\n",
    "        # Backward pass, then clip the gradient norm to 1.0\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        # Parameter update\n",
    "        optim.step()\n",
    "        # scheduler.step()\n",
    "\n",
    "        # Progress line every 100 batches\n",
    "        if iter_num % 100 == 0:\n",
    "            print(\"epoch: %d, iter_num: %d, loss: %.4f, %.2f%%\" %\n",
    "                  (epoch_index, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "\n",
    "    # max(..., 1) keeps the average defined for an empty loader\n",
    "    print(\"Epoch: %d, Average training loss: %.4f\" %\n",
    "          (epoch_index, total_train_loss/max(total_iter, 1)))\n",
    "\n",
    "\n",
    "def validation(dataloader=None):\n",
    "    \"\"\"Print accuracy and mean loss of ``model`` over a labelled loader.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of batches with 'input_ids', 'attention_mask'\n",
    "            and 'labels'. Defaults to ``test_dataloader``. The test set in\n",
    "            this notebook is built with placeholder labels (all zeros), so\n",
    "            pass a loader with real labels to get a meaningful score.\n",
    "    \"\"\"\n",
    "    if dataloader is None:\n",
    "        dataloader = test_dataloader\n",
    "\n",
    "    model.eval()\n",
    "    total_correct = 0.0\n",
    "    total_samples = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in dataloader:\n",
    "        with torch.no_grad():\n",
    "            # Forward pass only\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(\n",
    "                input_ids, attention_mask=attention_mask, labels=labels)\n",
    "\n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        # Weight each batch by its size so a short final batch does not\n",
    "        # count as much as a full one in the overall accuracy.\n",
    "        batch_size = len(label_ids)\n",
    "        total_correct += flat_accuracy(logits, label_ids) * batch_size\n",
    "        total_samples += batch_size\n",
    "\n",
    "    num_batches = len(dataloader)\n",
    "    # Both averages fall back to 0.0 when the loader is empty\n",
    "    avg_val_accuracy = total_correct / total_samples if total_samples else 0.0\n",
    "    avg_val_loss = total_eval_loss / num_batches if num_batches else 0.0\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\" %\n",
    "          (avg_val_loss))\n",
    "    print(\"-------------------------------\")\n",
    "\n",
    "NUM_EPOCHS = 2\n",
    "\n",
    "# The loop variable must stay named `epoch`: train() reads it for its log lines.\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    banner = \"------------Epoch: %d ----------------\" % epoch\n",
    "    print(banner)\n",
    "    train()\n",
    "    # validation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-17T07:43:28.106779Z",
     "iopub.status.busy": "2022-07-17T07:43:28.106358Z",
     "iopub.status.idle": "2022-07-17T07:43:28.112804Z",
     "shell.execute_reply": "2022-07-17T07:43:28.112201Z",
     "shell.execute_reply.started": "2022-07-17T07:43:28.106731Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prediction():\n",
    "    \"\"\"Return the arg-max class index for every sample in ``test_dataloader``.\n",
    "\n",
    "    The model is switched to eval mode and run without gradient tracking;\n",
    "    the indices are collected in loader order into one flat list.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    predicted_ids = []\n",
    "    with torch.no_grad():\n",
    "        for batch in test_dataloader:\n",
    "            ids = batch['input_ids'].to(device)\n",
    "            mask = batch['attention_mask'].to(device)\n",
    "\n",
    "            logits = model(ids, mask).logits\n",
    "            batch_ids = logits.argmax(1).detach().cpu().numpy()\n",
    "            predicted_ids.extend(batch_ids)\n",
    "    return predicted_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-17T07:43:29.086235Z",
     "iopub.status.busy": "2022-07-17T07:43:29.085629Z",
     "iopub.status.idle": "2022-07-17T07:44:05.383213Z",
     "shell.execute_reply": "2022-07-17T07:44:05.382780Z",
     "shell.execute_reply.started": "2022-07-17T07:43:29.086185Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run inference over the test loader; the result is a flat list of predicted class indices.\n",
    "test_predict = prediction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-17T07:44:05.384270Z",
     "iopub.status.busy": "2022-07-17T07:44:05.384041Z",
     "iopub.status.idle": "2022-07-17T07:44:05.393027Z",
     "shell.execute_reply": "2022-07-17T07:44:05.392579Z",
     "shell.execute_reply.started": "2022-07-17T07:44:05.384253Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Map predicted class indices back to topic names through the factorize lookup.\n",
    "predicted_topics = [lbl[class_id] for class_id in test_predict]\n",
    "test_df['Topic(Label)'] = predicted_topics\n",
    "\n",
    "# Write only the label column, without the index, as the submission file.\n",
    "submission = test_df[['Topic(Label)']]\n",
    "submission.to_csv('bert_submit3.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
